-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add model EMA (Exponential Moving Average) #114
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM, let's also add it do the docs (callback README)
Codecov ReportAttention: Patch coverage is
✅ All tests successful. No failed tests found.
Additional details and impacted files@@ Coverage Diff @@
## main #114 +/- ##
=======================================
Coverage ? 97.06%
=======================================
Files ? 141
Lines ? 6133
Branches ? 0
=======================================
Hits ? 5953
Misses ? 180
Partials ? 0 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This pull request introduces a new Exponential Moving Average (EMA) callback for model training, along with its integration and corresponding unit tests. The main changes include the addition of the
EMACallback
class, its registration, and tests to ensure its functionality.New EMA Callback
luxonis_train/callbacks/ema.py
: Added theEMACallback
class, which includes methods for initializing the EMA, updating EMA weights, and handling checkpoint saving/loading. This class helps in maintaining a moving average of model parameters to potentially improve model performance.Integration
luxonis_train/callbacks/__init__.py
: Registered theEMACallback
in theCALLBACKS
registry and added it to the__all__
list for proper module export. [1] [2] [3]Unit Tests
tests/unittests/test_callbacks/test_ema.py
: Added unit tests for theEMACallback
to verify its initialization, batch update, checkpoint saving/loading, and validation epoch handling.